BNN for Multiclass Classification

Neural Networks
Classification
Deep Learning
Neural networks that incorporate Bayesian inference to predict probabilities across three or more mutually exclusive classes.

General Principles

Building upon the Binary Classification BNN, the BNN Multiclass Classification model can handle dependent variables with K > 2 discrete categories.

Instead of the final layer returning a single output, the final layer in a multiclass BNN returns a K-dimensional vector of scores (logits) for each observation. To transform these continuous scores into valid probabilities that sum to 1 across all K classes, we apply the softmax activation function. Finally, the categorized predictions are evaluated using a Categorical likelihood.

Considerations

Note
  • Output Layer Dimensions: While binary classification network predictions can be compressed to a single output logit per observation, multiclass networks MUST output exactly K dimensions in their final layer, matching the number of target classes.
  • The Softmax Simplex: Applying the softmax function across the final layer’s logits guarantees that the resulting outputs form a probability simplex πŸ›ˆ. This is biologically similar to independent Poisson rates strictly normalizing to fixed categorical ratios.
  • Likelihood Function: After calculating the probabilities with softmax, we use a Categorical distribution as the final likelihood, matching the integer index of the observed category.
  • Improved Calibration: Multiclass BNNs greatly reduce out-of-distribution overconfidence. Standard deep learning cross-entropy models will often assign >99% probability to an unseen class purely due to the exponential nature of softmax. In a BNN, exploring the posterior width of the parameters yields β€œflat” unconfident probability profiles over K classes when the input is outside the training distribution.

Example

Below is an example code snippet demonstrating a Bayesian Neural Network for multiclass classification using the Bayesian Inference (BI) package. This example generates a synthetic K=3 cluster dataset.

Code
from BI import bi
import jax.numpy as jnp
import jax

# Setup device------------------------------------------------
m = bi(platform='cpu')

# Generate Synthetic Data ------------------------------------
# 3 classes based on a random normal distribution split
key = jax.random.PRNGKey(42)
X = jax.random.normal(key, (300, 2))
# Rule: Q1=Class 0, Q2/Q3=Class 1, Q4=Class 2
Y = jnp.where(X[:, 0] > 0, jnp.where(X[:, 1] > 0, 0, 1), 2)

m.data_on_model = dict(X=X, Y=Y)

# Define model ------------------------------------------------
def model(X, Y, D_H1=5, K=3):
    N, D_X = X.shape
    
    # First hidden layer: 2 input features -> 5 hidden units
    w1 = m.bnn.layer_linear(
        X, 
        dist=m.dist.normal(0, 1, name='w1_weight', shape=(D_X, D_H1)),
        activation='tanh'
    )
    
    # Final output layer: 5 hidden units -> K output units
    # Note: No activation is applied automatically inside the layer function here
    w2 = m.bnn.layer_linear(
        w1,
        dist=m.dist.normal(0, 1, name='w2_weight', shape=(D_H1, K))
    )
    
    # Apply Softmax across the K dimension (axis=-1) to yield probabilities
    p = jax.nn.softmax(w2, axis=-1)
    
    # Categorical Likelihood matching indices in Y
    m.dist.categorical(probs=p, obs=Y)

# Run mcmc ------------------------------------------------
m.fit(model) # Approximate posterior distributions

# Predictions from the model ------------------------------------------------
import matplotlib.pyplot as plt

# Create a grid to evaluate the model
n_grid = 50
x0 = jnp.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, n_grid)
x1 = jnp.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, n_grid)
xx0, xx1 = jnp.meshgrid(x0, x1)
X_grid = jnp.c_[xx0.ravel(), xx1.ravel()]

# Swap data on model temporarily to predict on the grid
m.data_on_model = dict(X=X_grid, Y=jnp.zeros(X_grid.shape[0], dtype=jnp.int32))
pred = m.sample(data = m.data_on_model)['x']
p_mean = jnp.mean(pred, axis=0)

# Plotting the posterior predictive mean (categorical blending)
fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
contour = ax.contourf(xx0, xx1, p_mean.reshape(n_grid, n_grid), cmap="viridis", alpha=0.6)
scatter = ax.scatter(X[:, 0], X[:, 1], c=Y, cmap="viridis", edgecolors='k')
ax.set(title="Posterior Predictive Mean", xlabel="Feature 1", ylabel="Feature 2")
fig.colorbar(contour, ax=ax)
/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
BI v 0.0.45 package loaded
jax.local_device_count 32
⚠️This function is still in development. Use it with caution. ⚠️
⚠️This function is still in development. Use it with caution. ⚠️
⚠️This function is still in development. Use it with caution. ⚠️
⚠️This function is still in development. Use it with caution. ⚠️
  0%|          | 0/2000 [00:00<?, ?it/s]Compiling.. :   0%|          | 0/2000 [00:00<?, ?it/s]
  0%|          | 0/2000 [00:00<?, ?it/s]
Compiling.. :   0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

Compiling.. :   0%|          | 0/2000 [00:00<?, ?it/s]


  0%|          | 0/2000 [00:00<?, ?it/s]


Compiling.. :   0%|          | 0/2000 [00:00<?, ?it/s]
⚠️This function is still in development. Use it with caution. ⚠️
⚠️This function is still in development. Use it with caution. ⚠️
Running chain 0:   0%|          | 0/2000 [00:00<?, ?it/s]
Running chain 1:   0%|          | 0/2000 [00:00<?, ?it/s]

Running chain 2:   0%|          | 0/2000 [00:00<?, ?it/s]


Running chain 3:   0%|          | 0/2000 [00:00<?, ?it/s]

Running chain 2:   5%|β–Œ         | 100/2000 [00:01<00:06, 313.80it/s]


Running chain 3:   5%|β–Œ         | 100/2000 [00:01<00:06, 295.21it/s]Running chain 0:   5%|β–Œ         | 100/2000 [00:01<00:06, 288.10it/s]
Running chain 1:   5%|β–Œ         | 100/2000 [00:01<00:07, 269.55it/s]


Running chain 3:  10%|β–ˆ         | 200/2000 [00:01<00:06, 257.20it/s]

Running chain 2:  10%|β–ˆ         | 200/2000 [00:01<00:07, 247.69it/s]
Running chain 1:  10%|β–ˆ         | 200/2000 [00:01<00:07, 241.77it/s]Running chain 0:  10%|β–ˆ         | 200/2000 [00:01<00:09, 192.63it/s]


Running chain 3:  15%|β–ˆβ–Œ        | 300/2000 [00:02<00:06, 260.28it/s]
Running chain 1:  15%|β–ˆβ–Œ        | 300/2000 [00:02<00:06, 247.50it/s]

Running chain 2:  15%|β–ˆβ–Œ        | 300/2000 [00:02<00:07, 238.83it/s]Running chain 0:  15%|β–ˆβ–Œ        | 300/2000 [00:02<00:07, 223.03it/s]

Running chain 2:  20%|β–ˆβ–ˆ        | 400/2000 [00:02<00:06, 257.61it/s]


Running chain 3:  20%|β–ˆβ–ˆ        | 400/2000 [00:02<00:06, 249.86it/s]
Running chain 1:  20%|β–ˆβ–ˆ        | 400/2000 [00:02<00:06, 257.96it/s]Running chain 0:  20%|β–ˆβ–ˆ        | 400/2000 [00:02<00:06, 236.89it/s]
Running chain 1:  25%|β–ˆβ–ˆβ–Œ       | 500/2000 [00:02<00:05, 267.50it/s]


Running chain 3:  25%|β–ˆβ–ˆβ–Œ       | 500/2000 [00:02<00:05, 255.83it/s]

Running chain 2:  25%|β–ˆβ–ˆβ–Œ       | 500/2000 [00:03<00:06, 236.19it/s]Running chain 0:  25%|β–ˆβ–ˆβ–Œ       | 500/2000 [00:03<00:06, 246.35it/s]
Running chain 1:  30%|β–ˆβ–ˆβ–ˆ       | 600/2000 [00:03<00:05, 277.59it/s]


Running chain 3:  30%|β–ˆβ–ˆβ–ˆ       | 600/2000 [00:03<00:05, 262.06it/s]

Running chain 2:  30%|β–ˆβ–ˆβ–ˆ       | 600/2000 [00:03<00:05, 252.91it/s]Running chain 0:  30%|β–ˆβ–ˆβ–ˆ       | 600/2000 [00:03<00:05, 267.94it/s]
Running chain 1:  35%|β–ˆβ–ˆβ–ˆβ–Œ      | 700/2000 [00:03<00:04, 295.83it/s]

Running chain 2:  35%|β–ˆβ–ˆβ–ˆβ–Œ      | 700/2000 [00:03<00:04, 284.22it/s]


Running chain 3:  35%|β–ˆβ–ˆβ–ˆβ–Œ      | 700/2000 [00:03<00:04, 267.21it/s]Running chain 0:  35%|β–ˆβ–ˆβ–ˆβ–Œ      | 700/2000 [00:03<00:04, 292.89it/s]
Running chain 1:  40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 800/2000 [00:03<00:03, 315.23it/s]

Running chain 2:  40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 800/2000 [00:03<00:03, 310.60it/s]


Running chain 3:  40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 800/2000 [00:03<00:04, 293.20it/s]Running chain 0:  40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 800/2000 [00:03<00:03, 300.83it/s]
Running chain 1:  45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 900/2000 [00:04<00:03, 337.04it/s]Running chain 0:  45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 900/2000 [00:04<00:03, 345.78it/s]


Running chain 3:  45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 900/2000 [00:04<00:03, 311.54it/s]

Running chain 2:  45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 900/2000 [00:04<00:03, 306.63it/s]
Running chain 1:  50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 1000/2000 [00:04<00:03, 326.86it/s]Running chain 0:  50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 1000/2000 [00:04<00:03, 328.70it/s]


Running chain 3:  50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 1000/2000 [00:04<00:03, 291.82it/s]

Running chain 2:  50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 1000/2000 [00:04<00:03, 287.96it/s]
Running chain 1:  55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 1100/2000 [00:04<00:02, 317.29it/s]Running chain 0:  55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 1100/2000 [00:04<00:02, 308.30it/s]


Running chain 3:  55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 1100/2000 [00:04<00:03, 289.18it/s]
Running chain 1:  60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 1200/2000 [00:05<00:02, 317.02it/s]

Running chain 2:  55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 1100/2000 [00:05<00:03, 263.09it/s]Running chain 0:  60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 1200/2000 [00:05<00:02, 300.92it/s]


Running chain 3:  60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 1200/2000 [00:05<00:02, 284.22it/s]
Running chain 1:  65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 1300/2000 [00:05<00:02, 319.08it/s]

Running chain 2:  60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 1200/2000 [00:05<00:03, 254.71it/s]Running chain 0:  65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 1300/2000 [00:05<00:02, 302.78it/s]
Running chain 1:  70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 1400/2000 [00:05<00:01, 327.62it/s]


Running chain 3:  65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 1300/2000 [00:05<00:02, 289.51it/s]

Running chain 2:  65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 1300/2000 [00:05<00:02, 251.21it/s]Running chain 0:  70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 1400/2000 [00:05<00:02, 299.71it/s]
Running chain 1:  75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 1500/2000 [00:05<00:01, 331.10it/s]


Running chain 3:  70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 1400/2000 [00:05<00:02, 294.21it/s]
Running chain 1:  80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 1600/2000 [00:06<00:01, 324.42it/s]Running chain 0:  75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 1500/2000 [00:06<00:01, 299.01it/s]


Running chain 3:  75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 1500/2000 [00:06<00:01, 292.58it/s]

Running chain 2:  70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 1400/2000 [00:06<00:02, 245.77it/s]
Running chain 1:  85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 1700/2000 [00:06<00:00, 324.22it/s]Running chain 0:  80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 1600/2000 [00:06<00:01, 286.78it/s]


Running chain 3:  80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 1600/2000 [00:06<00:01, 297.02it/s]

Running chain 2:  75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 1500/2000 [00:06<00:02, 240.13it/s]
Running chain 1:  90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 1800/2000 [00:06<00:00, 322.30it/s]Running chain 0:  85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 1700/2000 [00:06<00:01, 279.75it/s]


Running chain 3:  85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 1700/2000 [00:07<00:01, 284.14it/s]
Running chain 1:  95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 1900/2000 [00:07<00:00, 318.04it/s]

Running chain 2:  80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 1600/2000 [00:07<00:01, 231.67it/s]


Running chain 3:  90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 1800/2000 [00:07<00:00, 292.60it/s]Running chain 0:  90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 1800/2000 [00:07<00:00, 285.08it/s]
Running chain 1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:07<00:00, 326.52it/s]Running chain 1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:07<00:00, 267.86it/s]


Running chain 2:  85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 1700/2000 [00:07<00:01, 232.94it/s]


Running chain 3:  95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 1900/2000 [00:07<00:00, 286.45it/s]Running chain 0:  95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 1900/2000 [00:07<00:00, 281.42it/s]


Running chain 3: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:08<00:00, 290.65it/s]Running chain 3: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:08<00:00, 249.41it/s]
Running chain 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:08<00:00, 279.21it/s]Running chain 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:08<00:00, 248.08it/s]


Running chain 2:  90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 1800/2000 [00:08<00:00, 234.20it/s]

Running chain 2:  95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 1900/2000 [00:08<00:00, 242.01it/s]

Running chain 2: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:08<00:00, 258.02it/s]Running chain 2: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2000/2000 [00:08<00:00, 227.81it/s]
/home/sosa/work/BI/BI/Main/main.py:590: UserWarning:

Sample's batch dimension size 4000 is different from the provided 1 num_samples argument. Defaulting to 4000.
⚠️This function is still in development. Use it with caution. ⚠️
⚠️This function is still in development. Use it with caution. ⚠️

using BayesianInference
using PythonCall

# Setup device------------------------------------------------
m = importBI(platform="cpu")

# Generate Synthetic Data ------------------------------------
np = pyimport("numpy")
jax_random = pyimport("jax.random")
jnp = pyimport("jax.numpy")

key = jax_random.PRNGKey(42)
X = jax_random.normal(key, (300, 2))

# Simple rule to partition into K=3 classes
Y = jnp.where(X[:, 0] > 0, jnp.where(X[:, 1] > 0, 0, 1), 2)

m.data_on_model["X"] = X
m.data_on_model["Y"] = Y

# Define model ------------------------------------------------
@BI function model(X, Y)
    N, D_X = size(X)
    D_H1 = 5
    K = 3
    
    # First hidden layer
    w1 = m.bnn.layer_linear(
        X, 
        dist=m.dist.normal(0, 1, name="w1_weight", shape=(D_X, D_H1)),
        activation="tanh"
    )
    
    # Final output layer 
    w2 = m.bnn.layer_linear(
        w1,
        dist=m.dist.normal(0, 1, name="w2_weight", shape=(D_H1, K))
    )
    
    # Softmax conversion to probability simplex
    p = jax.nn.softmax(w2, axis=-1)
    
    # Categorical Likelihood
    m.dist.categorical(probs=p, obs=Y)
end

# Run mcmc ------------------------------------------------
m.fit(model, num_samples=500, progress_bar=false)

Mathematical Details

Bayesian Formulation

For a multiclass classification task spanning N observations and K mutually exclusive classes, we model the probability vector \theta_i that the response Y_i \in \{0, 1, ..., K-1\} falls into each respective class.

Using a single hidden layer with a hyperbolic tangent (\tanh) activation function, the model is structured as:

\begin{aligned} Y_i &\sim \text{Categorical}(\theta_i) \\ \theta_i & = \text{Softmax}(\phi_i) \\ \phi_i & = H_i \Theta_2 \\ H_i & = \tanh(X_i \Theta_1) \\ \Theta_1 & \sim \text{Normal}(0, 1) \\ \Theta_2 & \sim \text{Normal}(0, 1) \\ \end{aligned}

where:

  • Y_i is the observed class index for the i-th observation (Y_i \in \{0, 1, ..., K-1\}).
  • \theta_i is the predicted probability vector for the i-th observation.
  • \phi_i are the K-dimensional logits.
  • X_i is the input row vector for the i-th observation, with features length D_X = 2.
  • H_i is the hidden layer representation vector for the i-th observation. It has length D_H = 5.
  • \Theta_1 is the weight matrix of the first hidden layer (2 \times 5).
  • \Theta_2 is the final layer weight matrix mapping the hidden features to the logits for the K=3 classes (5 \times 3).
  • All elements within the weight matrices \Theta_1 and \Theta_2 are assigned independent standard Normal priors.

Notes

Note
  • For large outputs where K > 100, computing the exact softmax normalization scalar (the denominator term combining all exponentiated logits) can become computationally expensive over thousands of MCMC posterior evaluations.
  • Neural networks configured with a standard Cross-Entropy loss mapping to one-hot vectors conceptually perform exactly this sequence: dot product of final weights \rightarrow Softmax \rightarrow Categorical Likelihood.

Reference(s)

  1. PyData Berlin 2025: Introduction to Stochastic Variational Inference with NumPyro